Skip to content

[NPU] Add fused_linear_cross_entropy operator#1164

Open
lowdy1 wants to merge 2 commits intolinkedin:mainfrom
lowdy1:fused_ce
Open

[NPU] Add fused_linear_cross_entropy operator#1164
lowdy1 wants to merge 2 commits intolinkedin:mainfrom
lowdy1:fused_ce

Conversation

@lowdy1
Copy link
Copy Markdown
Contributor

@lowdy1 lowdy1 commented Mar 25, 2026

Summary

To address the UB overflow issue observed in the benchmark, we introduced an operator with an NPU-friendly implementation of fused linear cross entropy. This fused operator relies on several underlying operations (e.g., large matrix multiplication, softmax, and cross entropy), so its current benchmark performance is not yet optimal. Further optimization may be needed.

Testing Done

Device: Atlas A2
python -m pytest ./test/transformers/test_fused_linear_cross_entropy.py
image

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@lowdy1
Copy link
Copy Markdown
Contributor Author

lowdy1 commented Mar 25, 2026

**************************************
     BENCHMARKING SPEED for FUSED_LINEAR_CROSS_ENTROPY
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      916.869384765625,
      970.54833984375,
      1137.097412109375
    ],
    "y_values_20": [
      916.869384765625,
      970.54833984375,
      1137.097412109375
    ],
    "y_values_80": [
      916.869384765625,
      970.54833984375,
      1137.097412109375
    ],
    "timestamp": "2026-04-01 07:27:02",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger-fp32-accum",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      825.244384765625,
      879.1065063476562,
      1045.8411865234375
    ],
    "y_values_20": [
      825.244384765625,
      879.1065063476562,
      1045.8411865234375
    ],
    "y_values_80": [
      825.244384765625,
      879.1065063476562,
      1045.8411865234375
    ],
    "timestamp": "2026-04-01 07:27:47",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      23.265480041503906,
      47.650718688964844,
      96.68087768554688
    ],
    "y_values_20": [
      23.224380493164062,
      47.45820617675781,
      96.68087768554688
    ],
    "y_values_80": [
      23.351131439208984,
      47.84323501586914,
      96.68087768554688
    ],
    "timestamp": "2026-04-01 07:28:11",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      5.694960117340088,
      5.778989791870117,
      6.006700038909912
    ],
    "y_values_20": [
      5.692084312438965,
      5.776080131530762,
      6.00248384475708
    ],
    "y_values_80": [
      5.700056076049805,
      5.781300067901611,
      6.012008190155029
    ],
    "timestamp": "2026-04-01 07:28:38",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger-fp32-accum",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      5.6849799156188965,
      5.7753400802612305,
      6.017099857330322
    ],
    "y_values_20": [
      5.681600093841553,
      5.7729997634887695,
      6.013432025909424
    ],
    "y_values_80": [
      5.686679840087891,
      5.778719902038574,
      6.023987770080566
    ],
    "timestamp": "2026-04-01 07:29:04",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      74.56654357910156,
      126.84735870361328,
      235.11676025390625
    ],
    "y_values_20": [
      74.56654357910156,
      126.84735870361328,
      235.11676025390625
    ],
    "y_values_80": [
      74.56654357910156,
      126.84735870361328,
      235.11676025390625
    ],
    "timestamp": "2026-04-01 07:29:31",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      925.0958251953125,
      976.5956420898438,
      1144.0975341796875
    ],
    "y_values_20": [
      925.0958251953125,
      976.5956420898438,
      1144.0975341796875
    ],
    "y_values_80": [
      925.0958251953125,
      976.5956420898438,
      1144.0975341796875
    ],
    "timestamp": "2026-04-01 07:30:19",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger-fp32-accum",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      828.4945068359375,
      884.9970092773438,
      1052.578125
    ],
    "y_values_20": [
      828.4945068359375,
      884.9970092773438,
      1052.578125
    ],
    "y_values_80": [
      828.4945068359375,
      884.9970092773438,
      1052.578125
    ],
    "timestamp": "2026-04-01 07:31:04",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      98.43025970458984,
      175.3376007080078,
      330.91064453125
    ],
    "y_values_20": [
      98.43025970458984,
      175.3376007080078,
      330.91064453125
    ],
    "y_values_80": [
      98.43025970458984,
      175.3376007080078,
      330.91064453125
    ],
    "timestamp": "2026-04-01 07:31:32",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      917.0685424804688,
      970.683349609375,
      1137.266845703125
    ],
    "y_values_20": [
      917.0685424804688,
      970.683349609375,
      1137.266845703125
    ],
    "y_values_80": [
      917.0685424804688,
      970.683349609375,
      1137.266845703125
    ],
    "timestamp": "2026-04-01 07:32:20",
    "kernel_operation_mode": "no-grad-forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "liger-fp32-accum",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      825.2996215820312,
      882.8753051757812,
      1045.9044189453125
    ],
    "y_values_20": [
      825.2996215820312,
      882.8753051757812,
      1045.9044189453125
    ],
    "y_values_80": [
      825.2996215820312,
      882.8753051757812,
      1045.9044189453125
    ],
    "timestamp": "2026-04-01 07:33:05",
    "kernel_operation_mode": "no-grad-forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "fused_linear_cross_entropy",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "BT",
    "x_label": "B x T",
    "x_values": [
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      23.153650283813477,
      47.855567932128906,
      94.57154083251953
    ],
    "y_values_20": [
      23.059932708740234,
      47.7861328125,
      94.57154083251953
    ],
    "y_values_80": [
      23.219385147094727,
      47.92500686645508,
      94.57154083251953
    ],
    "timestamp": "2026-04-01 07:33:29",
    "kernel_operation_mode": "no-grad-forward",
    "extra_benchmark_config_str": "{\"H\": 4096, \"V\": 128256, \"mode\": \"forward\", \"dtype\": \"torch.bfloat16\"}",
    "liger_version": "0.7.0"
  }
]


# Here we calculate the gradient of logits_chunk in place so we can save memory.
# Grid size is capped at NPU core count; the kernel uses a grid-stride loop
liger_cross_entropy_kernel[(n_rows,)](
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we set grid size to num_cores?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In liger_cross_entropy_kernel, there is no loop over num_programs() as seen in other kernels, so n_rows is used instead.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, my local branch was outdated. This has been addressed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants